{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Client Server in Concrete ML\n",
    "\n",
    "Concrete ML allows machine learning practitioners to build FHE models. In this notebook, we present a simple case where a model is sent to a server to predict over encrypted data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import platform\n",
    "import time\n",
    "from shutil import copyfile\n",
    "from tempfile import TemporaryDirectory\n",
    "\n",
    "import numpy\n",
    "from sklearn.datasets import load_breast_cancer\n",
    "\n",
    "from concrete.ml.deployment import FHEModelClient, FHEModelDev, FHEModelServer\n",
    "from concrete.ml.sklearn import XGBClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OnDiskNetwork:\n",
    "    \"\"\"Simulate a network on disk.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Create 3 temporary folder for server, client and dev with tempfile\n",
    "        self.server_dir = TemporaryDirectory()  # pylint: disable=consider-using-with\n",
    "        self.client_dir = TemporaryDirectory()  # pylint: disable=consider-using-with\n",
    "        self.dev_dir = TemporaryDirectory()  # pylint: disable=consider-using-with\n",
    "\n",
    "    def client_send_evaluation_key_to_server(self, serialized_evaluation_keys):\n",
    "        \"\"\"Send the public key to the server.\"\"\"\n",
    "        with open(self.server_dir.name + \"/serialized_evaluation_keys.ekl\", \"wb\") as f:\n",
    "            f.write(serialized_evaluation_keys)\n",
    "\n",
    "    def client_send_input_to_server_for_prediction(self, encrypted_input):\n",
    "        \"\"\"Send the input to the server and execute on the server in FHE.\"\"\"\n",
    "        with open(self.server_dir.name + \"/serialized_evaluation_keys.ekl\", \"rb\") as f:\n",
    "            serialized_evaluation_keys = f.read()\n",
    "        time_begin = time.time()\n",
    "        encrypted_prediction = FHEModelServer(self.server_dir.name).run(\n",
    "            encrypted_input, serialized_evaluation_keys\n",
    "        )\n",
    "        time_end = time.time()\n",
    "        with open(self.server_dir.name + \"/encrypted_prediction.enc\", \"wb\") as f:\n",
    "            f.write(encrypted_prediction)\n",
    "        return time_end - time_begin\n",
    "\n",
    "    def dev_send_model_to_server(self):\n",
    "        \"\"\"Send the model to the server.\"\"\"\n",
    "        copyfile(self.dev_dir.name + \"/server.zip\", self.server_dir.name + \"/server.zip\")\n",
    "\n",
    "    def server_send_encrypted_prediction_to_client(self):\n",
    "        \"\"\"Send the encrypted prediction to the client.\"\"\"\n",
    "        with open(self.server_dir.name + \"/encrypted_prediction.enc\", \"rb\") as f:\n",
    "            encrypted_prediction = f.read()\n",
    "        return encrypted_prediction\n",
    "\n",
    "    def dev_send_clientspecs_and_modelspecs_to_client(self):\n",
    "        \"\"\"Send the clientspecs and evaluation key to the client.\"\"\"\n",
    "        copyfile(self.dev_dir.name + \"/client.zip\", self.client_dir.name + \"/client.zip\")\n",
    "\n",
    "    def cleanup(self):\n",
    "        \"\"\"Clean up the temporary folders.\"\"\"\n",
    "        self.server_dir.cleanup()\n",
    "        self.client_dir.cleanup()\n",
    "        self.dev_dir.cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Development Machine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model trained and compiled.\n"
     ]
    }
   ],
   "source": [
    "from concrete.compiler import check_gpu_available\n",
    "\n",
    "# Let's first get some data and train a model.\n",
    "X, y = load_breast_cancer(return_X_y=True)\n",
    "\n",
    "# Split X into X_model_owner and X_client\n",
    "X_model_owner, X_client = X[:-10], X[-10:]\n",
    "y_model_owner, y_client = y[:-10], y[-10:]\n",
    "\n",
    "# Some issues on macOS, if too many estimators\n",
    "n_estimators = 10\n",
    "if platform.system() == \"Darwin\":\n",
    "    n_estimators = 9\n",
    "\n",
    "\n",
    "use_gpu_if_available = False\n",
    "device = \"cuda\" if use_gpu_if_available and check_gpu_available() else \"cpu\"\n",
    "\n",
    "# Train the model and compile it\n",
    "model_dev = XGBClassifier(n_bits=2, n_estimators=n_estimators, max_depth=3)\n",
    "model_dev.fit(X_model_owner, y_model_owner)\n",
    "model_dev.compile(X_model_owner, device=device)\n",
    "\n",
    "print(\"Model trained and compiled.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's instantiate the network\n",
    "network = OnDiskNetwork()\n",
    "\n",
    "# Now that the model has been trained we want to save it to send it to a server\n",
    "fhemodel_dev = FHEModelDev(network.dev_dir.name, model_dev)\n",
    "fhemodel_dev.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 64K\r\n",
      "-rw-r--r-- 1 root root  59K Sep  5 14:28 client.zip\r\n",
      "-rw-r--r-- 1 root root 2.6K Sep  5 14:28 server.zip\r\n"
     ]
    }
   ],
   "source": [
    "# Print all files in the temporary directory along with their sizes in KB\n",
    "!ls -lh $network.dev_dir.name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the folder, we can see 3 different files:\n",
    " - client.zip - contains cryptographic parameters to be sent to the client for them to create the keys (can easily be served via HTTP request). also contains the description of the pre-processing and post-processing that must be applied before encryption and after decryption.\n",
    " - server.zip - contains everything required to do homomorphic prediction.\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 4.0K\r\n",
      "-rw-r--r-- 1 root root 2.6K Sep  5 14:28 server.zip\r\n"
     ]
    }
   ],
   "source": [
    "# Let's send the model to the server\n",
    "network.dev_send_model_to_server()\n",
    "!ls -lh $network.server_dir.name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 60K\r\n",
      "-rw-r--r-- 1 root root 59K Sep  5 14:28 client.zip\r\n"
     ]
    }
   ],
   "source": [
    "# Let's send the clientspecs and evaluation key to the client\n",
    "network.dev_send_clientspecs_and_modelspecs_to_client()\n",
    "!ls -lh $network.client_dir.name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Client Machine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "KeySetCache: miss, regenerating /tmp/tmpsf57p3ed/3497183917896914639\n"
     ]
    }
   ],
   "source": [
    "# Let's create the client and load the model\n",
    "fhemodel_client = FHEModelClient(network.client_dir.name, key_dir=network.client_dir.name)\n",
    "\n",
    "# The client first need to create the private and evaluation keys.\n",
    "serialized_evaluation_keys = fhemodel_client.get_serialized_evaluation_keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluation keys size: 29.14 MB\n"
     ]
    }
   ],
   "source": [
    "# Evaluation keys can be quite large files but only have to be shared once with the server.\n",
    "\n",
    "# Check the size of the evaluation keys (in MB)\n",
    "print(f\"Evaluation keys size: {len(serialized_evaluation_keys) / (10**6):.2f} MB\")\n",
    "\n",
    "# Let's send this evaluation key to the server (this has to be done only once)\n",
    "network.client_send_evaluation_key_to_server(serialized_evaluation_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Encrypted data is 4.10 times larger than the clear data\n",
      "The average execution time is 0.62 seconds per sample.\n"
     ]
    }
   ],
   "source": [
    "# Now we have everything for the client to interact with the server\n",
    "\n",
    "# We create a loop to send the input to the server and receive the encrypted prediction\n",
    "decrypted_predictions = []\n",
    "execution_time = []\n",
    "for i in range(X_client.shape[0]):\n",
    "    clear_input = X_client[[i], :]\n",
    "    encrypted_input = fhemodel_client.quantize_encrypt_serialize(clear_input)\n",
    "    execution_time += [network.client_send_input_to_server_for_prediction(encrypted_input)]\n",
    "    encrypted_prediction = network.server_send_encrypted_prediction_to_client()\n",
    "    decrypted_prediction = fhemodel_client.deserialize_decrypt_dequantize(encrypted_prediction)[0]\n",
    "    decrypted_predictions.append(decrypted_prediction)\n",
    "\n",
    "# Check MB size with sys of the encrypted data vs clear data\n",
    "print(\n",
    "    f\"Encrypted data is \"\n",
    "    f\"{len(encrypted_input)/clear_input.nbytes:.2f}\"\n",
    "    \" times larger than the clear data\"\n",
    ")\n",
    "\n",
    "# Show execution time\n",
    "print(f\"The average execution time is {numpy.mean(execution_time):.2f} seconds per sample.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy between FHE prediction and clear model is: 100%\n"
     ]
    }
   ],
   "source": [
    "# Let's check the results and compare them against the clear model\n",
    "clear_prediction_classes = model_dev.predict_proba(X_client).argmax(axis=1)\n",
    "decrypted_predictions_classes = numpy.array(decrypted_predictions).argmax(axis=1)\n",
    "accuracy = (clear_prediction_classes == decrypted_predictions_classes).mean()\n",
    "print(f\"Accuracy between FHE prediction and clear model is: {accuracy*100:.0f}%\")"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
